from typing import Any, Dict, Optional

from axelrod.action import Action

from axelrod.evolvable_player import (
    EvolvablePlayer,
    InsufficientParametersError,
    copy_lists,
    crossover_lists,
)

from axelrod.player import Player

# Module-level shorthand for Action.C and Action.D.
C = Action.C
D = Action.D

def is_stochastic_matrix(m, ep=1e-8) -> bool:
    """Return True when every row of `m` is a probability distribution.

    `m` is a list of lists. Each entry has to lie in [0, 1] and each row has
    to sum to 1 within the absolute tolerance `ep`. A matrix with no rows
    passes; a row with no entries fails because its sum is 0.
    """
    for row in m:
        # Range check first, so an out-of-range entry is rejected even when
        # the row happens to sum to 1.
        if any((entry < 0) or (entry > 1) for entry in row):
            return False
        if abs(1.0 - sum(row)) > ep:
            return False
    return True

def normalize_vector(vec):
    """Scale the entries of `vec` so that they sum to 1.

    A new list is returned; `vec` itself is left untouched. Every entry is
    divided by `sum(vec)`, so with plain Python numbers a vector whose
    entries sum to zero raises ZeroDivisionError. An empty `vec` yields an
    empty list.
    """
    total = sum(vec)
    return [entry / total for entry in vec]

def mutate_row(row, mutation_probability, rng):
    """Randomly perturb the entries of a row of probabilities, in place.

    One draw per entry is taken up front from `rng.random(len(row))`. An
    entry is changed when its draw is below `mutation_probability` (a value
    between 0 and 1). A changed entry is shifted by `rng.uniform(-1, 1) / 4`
    (i.e. a value in [-0.25, 0.25], assuming `rng.uniform` draws from the
    given interval) and then clipped to [0, 1].

    The row is not re-normalized here, so its entries may no longer sum to 1
    afterwards.

    Parameters
    ----------
    row: mutable sequence of numbers; modified in place
    mutation_probability: threshold the per-entry draws are compared against
    rng: object providing `random(n)` and `uniform(low, high)`

    Returns
    -------
    The same `row` object that was passed in.
    """
    draws = rng.random(len(row))
    for idx, draw in enumerate(draws):
        if draw < mutation_probability:
            shifted = row[idx] + rng.uniform(-1, 1) / 4
            # Keep the mutated value a valid probability.
            if shifted < 0:
                shifted = 0
            elif shifted > 1:
                shifted = 1
            row[idx] = shifted
    return row

class SimpleHMM(object):
    """Implementation of a basic Hidden Markov Model. We assume that the
    transition matrix is conditioned on the opponent's last action, so there
    are two transition matrices. Emission distributions are stored as Bernoulli
    probabilities for each state. This is essentially a stochastic FSM.

    https://en.wikipedia.org/wiki/Hidden_Markov_model
    """

    def __init__(
        self,
        transitions_C,
        transitions_D,
        emission_probabilities,
        initial_state,
    ) -> None:
        """
        Params
        ------
        transitions_C and transitions_D are square stochastic matrices:
            lists of lists with all values in [0, 1] and rows that sum to 1.
        emission_probabilities is a vector of values in [0, 1]
        initial_state is an element of range(0, len(emission_probabilities))

        The parameters are stored as given and are not validated here; call
        `is_well_formed` to check them.
        """
        self.transitions_C = transitions_C
        self.transitions_D = transitions_D
        self.emission_probabilities = emission_probabilities
        self.state = initial_state
        self._cache_C = dict()  # type: Dict[int, int]
        self._cache_D = dict()  # type: Dict[int, int]
        self._cache_deterministic_transitions()
        # Random generator will be set by parent strategy
        self._random = None  # type: Any

    @staticmethod
    def _deterministic_targets(matrix):
        """Map each state whose row contains a 1 to the position of that 1."""
        return {
            state: row.index(1)
            for state, row in enumerate(matrix)
            if 1 in row
        }

    def _cache_deterministic_transitions(self):
        """Cache deterministic transitions to avoid unnecessary random draws.

        If 1 is in a transition row, that transition is deterministic, so the
        target state is looked up once here instead of on every move. Both
        caches are rebuilt from scratch so that no entry from an earlier call
        can survive.
        """
        self._cache_C = self._deterministic_targets(self.transitions_C)
        self._cache_D = self._deterministic_targets(self.transitions_D)

    def is_well_formed(self) -> bool:
        """
        Determines if the HMM parameters are well-formed:
            - Both matrices are square, with one row and one column per
              state (the number of states is the number of emission
              probabilities)
            - Both matrices are stochastic
            - Emissions probabilities are in [0, 1]
            - The initial state is valid.
        """
        num_states = len(self.emission_probabilities)
        for matrix in (self.transitions_C, self.transitions_D):
            # A mis-shaped matrix would let `move` select a state that has no
            # emission probability, or hand the random generator a
            # probability vector of the wrong length.
            if len(matrix) != num_states:
                return False
            if any(len(row) != num_states for row in matrix):
                return False
            if not is_stochastic_matrix(matrix):
                return False
        if any((p < 0) or (p > 1) for p in self.emission_probabilities):
            return False
        return self.state in range(num_states)

    def __eq__(self, other: object) -> bool:
        """Equality of two HMMs.

        Two HMMs are equal when their transition matrices, emission
        probabilities and current state are equal. Objects that do not carry
        all of those attributes are not comparable, so NotImplemented is
        returned and Python falls back to its default comparison (rather
        than raising AttributeError).

        Note: because __eq__ is defined without __hash__, instances are
        unhashable.
        """
        compared = (
            "transitions_C",
            "transitions_D",
            "emission_probabilities",
            "state",
        )
        if not all(hasattr(other, attr) for attr in compared):
            return NotImplemented
        return all(
            getattr(self, attr) == getattr(other, attr) for attr in compared
        )

    def move(self, opponent_action: Action) -> Action:
        """Changes state and computes the response action.

        Parameters
            opponent_action: Axelrod.Action
                The opponent's last action.

        Returns
            D if the emission probability of the new state is 0, C if it is
            1, and otherwise whatever `self._random.random_choice(p)` returns
            for that probability p.

        `self._random` must have been set before any transition or emission
        that is not deterministic is needed.
        """
        # Choose next state. Anything other than C uses the D matrix.
        if opponent_action == C:
            cache, transitions = self._cache_C, self.transitions_C
        else:
            cache, transitions = self._cache_D, self.transitions_D
        try:
            next_state = cache[self.state]
        except KeyError:
            # Not deterministic: sample the next state from the row of the
            # current state.
            num_states = len(self.emission_probabilities)
            next_state = self._random.choice(
                num_states, 1, p=transitions[self.state]
            )[0]

        self.state = next_state
        # Choose action to emit. Deterministic emissions skip the random draw.
        p = self.emission_probabilities[self.state]
        if p == 0:
            return D
        if p == 1:
            return C
        return self._random.random_choice(p)